import torch

a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
c = torch.rand(4, 32, 8)

# 1、cat拼接，dim指定拼接的维度
print(torch.cat([a, b], dim=0).shape)
# 2、stack拼接，会创建一个新的维度，用于区分组合后的两个tensor
print(torch.stack([a, c], dim=0).shape)

# 3、split拆分
_1, _2 = a.split(2, dim=0)  # 0维度拆开，没两个一拆
print(_1.shape, _2.shape)
_1, _2, _3 = a.split([1, 2, 1], dim=0)  # 0维度拆开，拆的长度分别为1，2，1
print(_1.shape, _2.shape, _3.shape)

# 4、chunk拆分
_1, _2 = a.chunk(2, dim=0)
print(_1.shape, _2.shape)
